{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CDuhyGZEWPUW"
   },
   "source": [
    "# Mount Google Drive and import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1181,
     "status": "ok",
     "timestamp": 1589786273013,
     "user": {
      "displayName": "Ramin Akhtar",
      "photoUrl": "",
      "userId": "18054088260314135543"
     },
     "user_tz": -270
    },
    "id": "jlJfKefWCjWh",
    "outputId": "29ac54fd-f3ab-466e-d7bf-c7e2ee99eeb1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorFlow 1.x selected.\n"
     ]
    }
   ],
   "source": [
    "# Colab-only magic: select TensorFlow 1.x for the Keras code in this notebook.\n",
    "%tensorflow_version 1.x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 309
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 60979,
     "status": "ok",
     "timestamp": 1589786332830,
     "user": {
      "displayName": "Ramin Akhtar",
      "photoUrl": "",
      "userId": "18054088260314135543"
     },
     "user_tz": -270
    },
    "id": "RRBHhYHMXAG0",
    "outputId": "682ecbbb-4f5c-4c3d-d788-23ac0a7f9675"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n",
      "\n",
      "Enter your authorization code:\n",
      "··········\n",
      "Mounted at /content/drive\n",
      "/content/drive/My Drive/BE pro\n",
      "'confusion matrix_results.docx'\t\t 'model results'\n",
      " data\t\t\t\t\t 'Proposed method.ipynb'\n",
      " Deep_Learning_EEG_Classification.ipynb   __pycache__\n",
      " depression-rest-preprocessed.zip\t 'regression predict.docx'\n",
      " document.docx\t\t\t\t 'related work results'\n",
      " dreegdl\t\t\t\t 'Related works.ipynb'\n",
      "'ICA results'\t\t\t\t  results.docx\n",
      " labels.xlsx\t\t\t\t  TCN.docx\n",
      " model.gdoc\t\t\t\t  tcn.py\n",
      " model.png\n"
     ]
    }
   ],
   "source": [
    "# Mount Google Drive and work from the project folder, which holds labels.xlsx, data/,\n",
    "# tcn.py and the dreegdl package imported below (see the listing printed by `! ls`).\n",
    "from google.colab import drive\n",
    "drive.mount('/content/drive')\n",
    "%cd drive/My Drive/BE pro\n",
    "! ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_q6kr5H-jc-G"
   },
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from os import listdir\n",
    "from os.path import isfile, join\n",
    "from keras.backend import clear_session\n",
    "from sklearn.model_selection import train_test_split, StratifiedKFold\n",
    "from sklearn.preprocessing import StandardScaler, OneHotEncoder, MinMaxScaler\n",
    "from keras.layers import Dense, concatenate, Dropout, Conv1D, MaxPooling1D, Flatten, LeakyReLU, Conv2D, Reshape, LSTM\n",
    "from keras.models import Input, Model\n",
    "from keras.initializers import glorot_normal\n",
    "from keras import optimizers\n",
    "import keras.backend as K\n",
    "from tcn import TCN\n",
    "from keras.callbacks import LearningRateScheduler\n",
    "from sklearn.utils.class_weight import compute_class_weight\n",
    "from numba import cuda\n",
    "from imblearn.metrics import sensitivity_specificity_support\n",
    "from sklearn.svm import SVR, SVC\n",
    "from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier\n",
    "from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import warnings\n",
    "warnings.simplefilter(action='ignore')\n",
    "from dreegdl import Config, Preprocessing, Kaggle, Training, Utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jxqNBacqVjhT"
   },
   "source": [
    "# Define data-loading helpers and models\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "fru6WSmqeYJM"
   },
   "outputs": [],
   "source": [
    "def pad_N_epochs(x,limit):\n",
    "  out = []\n",
    "  for y in x:\n",
    "    while len(y) < limit:\n",
    "      y = np.concatenate([y,y[-(limit-len(y)):]])\n",
    "    out.append(y[:limit])\n",
    "  return np.array(out)\n",
    "\n",
    "\n",
    "def crop_N_epochs(x,limit):\n",
    "  \"\"\"Return every sequence of `x` padded/truncated to `limit` items along axis 0.\n",
    "\n",
    "  NOTE(review): this delegates to dreegdl's `Utils.pad_N_epochs`, not to the\n",
    "  `pad_N_epochs` defined in this cell; the extra `[:limit]` only matters if that\n",
    "  helper can return sequences longer than `limit` -- confirm. It is not called\n",
    "  elsewhere in this notebook.\n",
    "  \"\"\"\n",
    "  padded = Utils.pad_N_epochs(x, limit)\n",
    "  return np.array([seq[:limit] for seq in padded])\n",
    "\n",
    "\n",
    "def load_data():\n",
    "  \"\"\"Load every subject's EEG feature file from ./data/ together with its BDI label.\n",
    "\n",
    "  Labels come from the 'Depression Rest' sheet of ./labels.xlsx (columns 'id', 'BDI').\n",
    "  Each file whose name contains '_Depression_REST-epo-feat-v1' is loaded, each of its\n",
    "  entries is padded/truncated to 120 items with `pad_N_epochs`, and the result is\n",
    "  reshaped to 3-D. Every row of a file gets the BDI score of the subject whose id is\n",
    "  the integer filename prefix before the first '_'.\n",
    "\n",
    "  Returns:\n",
    "    (all_data, all_label): the rows of all files concatenated along axis 0, and one\n",
    "    uint8 BDI label per row.\n",
    "\n",
    "  Raises:\n",
    "    FileNotFoundError: if ./data/ contains no matching feature file.\n",
    "  \"\"\"\n",
    "  ##-- load labels (only the sheet that is used)\n",
    "  labels = pd.read_excel('./labels.xlsx', sheet_name='Depression Rest')\n",
    "  ##-- load data\n",
    "  base_dir = './data/'\n",
    "  # sorted() makes the row order -- and therefore the seeded fold split later in this\n",
    "  # notebook -- independent of the arbitrary order in which listdir returns entries.\n",
    "  feature_files = sorted(f for f in listdir(base_dir)\n",
    "                         if isfile(join(base_dir, f)) and '_Depression_REST-epo-feat-v1' in f)\n",
    "  if not feature_files:\n",
    "    raise FileNotFoundError('no *_Depression_REST-epo-feat-v1* files found in ' + base_dir)\n",
    "  data = []\n",
    "  label = []\n",
    "  for fname in feature_files:\n",
    "    # NOTE(review): allow_pickle=True can execute code from a crafted file; only load trusted data.\n",
    "    temp = np.load(base_dir + fname, allow_pickle=True)\n",
    "    temp = pad_N_epochs(temp, 120)\n",
    "    temp = temp.reshape(temp.shape[0], temp.shape[1], -1)\n",
    "    data.append(temp)\n",
    "    subject_id = int(fname.split('_')[0])\n",
    "    bdi = labels[labels['id'] == subject_id].BDI.values[0]\n",
    "    label.append(np.full(temp.shape[0], bdi))\n",
    "  all_data = np.concatenate(data)\n",
    "  all_label = np.concatenate(label).astype('uint8')\n",
    "  print('data shape:', all_data.shape)\n",
    "  return all_data, all_label\n",
    "\n",
    "\n",
    "def change_to_4(all_label):\n",
    "  L = np.copy(all_label)\n",
    "  a = np.unique(all_label)\n",
    "  for i in a:\n",
    "      idx = np.where(all_label==i)[0]\n",
    "      if i>=0 and i<=13:\n",
    "          L[idx] = 0\n",
    "      elif i>=14 and i<=19:\n",
    "          L[idx] = 1\n",
    "      elif i>=20 and i<=28:\n",
    "          L[idx] = 2\n",
    "      else:\n",
    "          L[idx] = 3\n",
    "          \n",
    "  all_label = np.copy(L)\n",
    "  del L, a\n",
    "  return all_label\n",
    "\n",
    "\n",
    "def change_to_2(data, label, mode):\n",
    "  idx = []\n",
    "  for i in range(len(mode)):\n",
    "    idx.append(np.where(label==mode[i])[0])\n",
    "  idx = np.concatenate(idx)\n",
    "  data = data[idx]\n",
    "  label = label[idx]\n",
    "  return data, label\n",
    "\n",
    "\n",
    "def prepare_data(data, labels, mode):\n",
    "  if len(mode)==1: # 4 class classification or regression\n",
    "    if mode[0]==0: # do regression\n",
    "      print('Problem: regression')\n",
    "      return data, labels\n",
    "    else: # else it is 4 that means do 4class classification\n",
    "      print('Problem: 4 class')\n",
    "      labels = change_to_4(labels)\n",
    "      return data, labels\n",
    "  else: # that mean do 2class classification\n",
    "    print('Problem: 2 class')\n",
    "    labels = change_to_4(labels)\n",
    "    data, labels = change_to_2(data, labels, mode)\n",
    "    return data, labels\n",
    "  \n",
    "\n",
    "def DCNN(input_dim, timesteps, num_class):\n",
    "  \"\"\"1-D CNN: four Conv1D(kernel 5) + LeakyReLU + MaxPool(2) stages, then two dense layers.\n",
    "\n",
    "  Input shape is (timesteps, input_dim). num_class == 1 builds a single linear output\n",
    "  compiled with MSE (regression); otherwise a softmax over num_class outputs compiled\n",
    "  with categorical cross-entropy and accuracy. Optimizer: Adam, lr=1e-4.\n",
    "  \"\"\"\n",
    "  i = Input(shape=(timesteps, input_dim))\n",
    "  h = Conv1D(filters=5, kernel_size=5, name='conv1')(i)\n",
    "  h = LeakyReLU(name='lr1')(h)\n",
    "  # `strides` is the documented Keras 2 argument name (as used in CNN_LSTM);\n",
    "  # `stride` is the legacy Keras 1 spelling.\n",
    "  h = MaxPooling1D(pool_size=2, strides=2, name='mp1')(h)\n",
    "  h = Conv1D(filters=5, kernel_size=5, name='conv2')(h)\n",
    "  h = LeakyReLU(name='lr2')(h)\n",
    "  h = MaxPooling1D(pool_size=2, strides=2, name='mp2')(h)\n",
    "  h = Conv1D(filters=10, kernel_size=5, name='conv3')(h)\n",
    "  h = LeakyReLU(name='lr3')(h)\n",
    "  h = MaxPooling1D(pool_size=2, strides=2, name='mp3')(h)\n",
    "  h = Conv1D(filters=10, kernel_size=5, name='conv4')(h)\n",
    "  h = LeakyReLU(name='lr4')(h)\n",
    "  h = MaxPooling1D(pool_size=2, strides=2, name='mp4')(h)\n",
    "  h = Flatten(name='flt')(h)\n",
    "  h = Dropout(0.1, name='drp1')(h)\n",
    "  h = Dense(80, name='fc1')(h)\n",
    "  h = LeakyReLU(name='lr5')(h)\n",
    "  h = Dropout(0.1, name='drp2')(h)\n",
    "  h = Dense(40, name='fc2')(h)\n",
    "  h = LeakyReLU(name='lr6')(h)\n",
    "  if num_class == 1:\n",
    "    h = Dense(1, name='fc3')(h)\n",
    "  else:\n",
    "    h = Dense(num_class, activation='softmax', name='fc3')(h)\n",
    "\n",
    "  model = Model(inputs=i, outputs=h)\n",
    "  adamopt = optimizers.Adam(lr=0.0001)\n",
    "  if num_class==1:\n",
    "    model.compile(optimizer=adamopt, loss='mse')\n",
    "  else:\n",
    "    model.compile(optimizer=adamopt, loss='categorical_crossentropy', metrics=['acc'])\n",
    "  return model\n",
    "\n",
    "\n",
    "def DCNN2(input_dim, timesteps, num_class):\n",
    "  \"\"\"2-D + 1-D CNN on inputs of shape (timesteps, input_dim, 1).\n",
    "\n",
    "  Four ReLU Conv2D layers with (1, 3) kernels (32, 32, 64, 64 filters, Glorot-normal\n",
    "  init seeded with 42), a reshape to (rows, cols * channels), two ReLU Conv1D layers\n",
    "  (16, 32 filters, kernel 3), dropout 0.5 and a 512-unit ReLU dense layer.\n",
    "  num_class == 1 -> one linear output trained with MSE; otherwise a softmax over\n",
    "  num_class outputs trained with categorical cross-entropy (metric: accuracy).\n",
    "  Optimizer: Adam, lr=1e-3.\n",
    "  \"\"\"\n",
    "  krl_init = glorot_normal(seed=42)\n",
    "  inputs = Input(shape=(timesteps, input_dim, 1))\n",
    "  x = inputs\n",
    "  for layer_name, n_filters in (('conv2_1', 32), ('conv2_2', 32), ('conv2_3', 64), ('conv2_4', 64)):\n",
    "    x = Conv2D(filters=n_filters, kernel_size=(1, 3), strides=1, activation='relu',\n",
    "               name=layer_name, kernel_initializer=krl_init, data_format='channels_last')(x)\n",
    "  # merge the last two axes so the 1-D convolutions slide along the first spatial axis\n",
    "  _, rows, cols, chans = K.int_shape(x)\n",
    "  x = Reshape((rows, cols * chans))(x)\n",
    "  for layer_name, n_filters in (('conv1_1', 16), ('conv1_2', 32)):\n",
    "    x = Conv1D(filters=n_filters, kernel_size=3, strides=1, activation='relu',\n",
    "               name=layer_name, kernel_initializer=krl_init)(x)\n",
    "  x = Flatten(name='flt')(x)\n",
    "  x = Dropout(0.5, name='drp1')(x)\n",
    "  x = Dense(512, activation='relu', name='fc1')(x)\n",
    "  regression = num_class == 1\n",
    "  if regression:\n",
    "    outputs = Dense(1, name='fc2')(x)\n",
    "  else:\n",
    "    outputs = Dense(num_class, activation='softmax', name='fc2')(x)\n",
    "  model = Model(inputs=inputs, outputs=outputs)\n",
    "  adamopt = optimizers.adam(lr=0.001)\n",
    "  if regression:\n",
    "    model.compile(optimizer=adamopt, loss='mse')\n",
    "  else:\n",
    "    model.compile(optimizer=adamopt, loss='categorical_crossentropy', metrics=['acc'])\n",
    "  return model\n",
    "\n",
    "\n",
    "def CNN_LSTM(input_dim, timesteps, num_class):\n",
    "  \"\"\"Conv1D feature extractor followed by an LSTM, for inputs of shape (timesteps, input_dim).\n",
    "\n",
    "  Conv1D 64(k5) -> Conv1D 128(k3) -> MaxPool(2) -> Dropout 0.2 -> Conv1D 128(k13) ->\n",
    "  Conv1D 32(k7) -> LSTM(32, full sequence) -> Flatten -> Dense 64 -> Dropout 0.2 -> output.\n",
    "  num_class == 1 -> one linear output trained with MSE; otherwise a softmax over\n",
    "  num_class outputs trained with categorical cross-entropy (metric: accuracy).\n",
    "  Optimizer: Adam, lr=1e-4.\n",
    "  \"\"\"\n",
    "  inputs = Input(shape=(timesteps, input_dim))\n",
    "  x = Conv1D(filters=64, kernel_size=5, strides=1, activation='relu', name='conv1')(inputs)\n",
    "  x = Conv1D(filters=128, kernel_size=3, strides=1, activation='relu', name='conv2')(x)\n",
    "  x = MaxPooling1D(pool_size=2, strides=2, name='mp1')(x)\n",
    "  x = Dropout(0.2, name='drp1')(x)\n",
    "  x = Conv1D(filters=128, kernel_size=13, strides=1, activation='relu', name='conv3')(x)\n",
    "  x = Conv1D(filters=32, kernel_size=7, strides=1, activation='relu', name='conv4')(x)\n",
    "  x = LSTM(32, return_sequences=True, name='lstm1')(x)\n",
    "  x = Flatten(name='flt')(x)\n",
    "  x = Dense(64, activation='relu', name='fc1')(x)\n",
    "  x = Dropout(0.2, name='drp2')(x)\n",
    "  regression = num_class == 1\n",
    "  if regression:\n",
    "    outputs = Dense(1, name='fc2')(x)\n",
    "  else:\n",
    "    outputs = Dense(num_class, activation='softmax', name='fc2')(x)\n",
    "  model = Model(inputs=inputs, outputs=outputs)\n",
    "  adamopt = optimizers.adam(lr=0.0001)\n",
    "  if regression:\n",
    "    model.compile(optimizer=adamopt, loss='mse')\n",
    "  else:\n",
    "    model.compile(optimizer=adamopt, loss='categorical_crossentropy', metrics=['acc'])\n",
    "  return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QMbzlAeTYAby"
   },
   "source": [
    "# Train models with 10-fold cross-validation and save the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5hoWL8ZKYDvp"
   },
   "outputs": [],
   "source": [
    "# mode selects the problem (see prepare_data): [0] -> BDI regression,\n",
    "# [k] with k != 0 -> 4-class, [a, b] -> binary between change_to_4 groups a and b.\n",
    "mode = [0, 2]\n",
    "model_num = 5 # 1 --> DCNN || 2 --> DCNN2 || 3 --> CNN_LSTM || 4 --> SVM || 5 --> KNN || 6 --> Random Forest\n",
    "print('Loading data...')\n",
    "data, labels = load_data()\n",
    "print('preparing data...')\n",
    "data, labels = prepare_data(data, labels, mode)\n",
    "skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=21)\n",
    "sc = StandardScaler()\n",
    "encoder = OneHotEncoder(sparse=False)\n",
    "counter = 0\n",
    "# per-fold scores (MSE for regression, accuracy otherwise) and binary-only metrics\n",
    "train_result, test_result = [], []\n",
    "train_sensitivity, train_specificity = [], []\n",
    "test_sensitivity, test_specificity = [], []\n",
    "# confusion matrices summed over folds (binary problem only)\n",
    "train_conf_mat, test_conf_mat = np.zeros((2, 2)), np.zeros((2, 2))\n",
    "# out-of-fold regression predictions, one row per sample\n",
    "reg_pred = np.zeros((data.shape[0], 1))\n",
    "print('starting 10 fold...')\n",
    "# NOTE(review): StratifiedKFold splits individual rows, so rows loaded from the same subject\n",
    "# file can end up in both the training and the test fold of a split, which is likely to\n",
    "# inflate test scores. A subject-wise split (e.g. GroupKFold on the subject id) would need\n",
    "# load_data to also return subject ids -- TODO.\n",
    "for train_index, test_index in skf.split(data, labels):\n",
    "  counter +=1\n",
    "  print('Fold:', counter)\n",
    "  train_x, test_x = data[train_index], data[test_index]\n",
    "  train_y, test_y = labels[train_index].reshape(-1, 1), labels[test_index].reshape(-1, 1)\n",
    "\n",
    "  # Standardize with statistics from this fold's training rows only. partial_fit\n",
    "  # accumulates, so a scaler shared across folds would also absorb earlier folds'\n",
    "  # rows -- including rows that sit in this fold's test set. Start fresh every fold.\n",
    "  sc = StandardScaler()\n",
    "  for i in range(train_x.shape[0]):\n",
    "    sc.partial_fit(train_x[i])\n",
    "  for i in range(train_x.shape[0]):\n",
    "    train_x[i] = sc.transform(train_x[i])\n",
    "  for i in range(test_x.shape[0]):\n",
    "    test_x[i] = sc.transform(test_x[i])\n",
    "\n",
    "  is_regression = len(mode)==1 and mode[0]==0\n",
    "  is_classic = model_num in (4, 5, 6) # SVM / KNN / Random Forest\n",
    "  if is_regression:\n",
    "    num_class = 1\n",
    "  elif not is_classic: # classification with a neural network: one-hot targets\n",
    "    train_y = encoder.fit_transform(train_y.reshape(-1,1))\n",
    "    test_y = encoder.transform(test_y.reshape(-1,1))\n",
    "    num_class = train_y.shape[1]\n",
    "  else: # classification with a classic model: integer targets\n",
    "    unique_labels = np.unique(train_y)\n",
    "    num_class = unique_labels.shape[0]\n",
    "    if len(mode)==2:\n",
    "      # remap the two kept groups to 0/1 (ascending) so the binary metrics below treat 1 as positive\n",
    "      for count, lab in enumerate(unique_labels):\n",
    "        train_y[train_y==lab] = count\n",
    "        test_y[test_y==lab] = count\n",
    "\n",
    "  (_, timesteps, num_feats) = train_x.shape\n",
    "\n",
    "  clear_session()\n",
    "  if model_num==1:\n",
    "    model = DCNN(num_feats, timesteps, num_class)\n",
    "    model.fit(train_x, train_y,\n",
    "              batch_size=5, epochs=10,\n",
    "              shuffle=True, verbose=0,\n",
    "              validation_data=(test_x, test_y))\n",
    "  elif model_num==2:\n",
    "    model = DCNN2(num_feats, timesteps, num_class)\n",
    "    # DCNN2 expects a trailing channel axis\n",
    "    train_x = train_x.reshape(train_x.shape + (1,))\n",
    "    test_x = test_x.reshape(test_x.shape + (1,))\n",
    "    model.fit(train_x, train_y,\n",
    "              batch_size=5, epochs=10,\n",
    "              shuffle=True, verbose=0,\n",
    "              validation_data=(test_x, test_y))\n",
    "  elif model_num==3:\n",
    "    model = CNN_LSTM(num_feats, timesteps, num_class)\n",
    "    model.fit(train_x, train_y,\n",
    "              batch_size=64, epochs=10,\n",
    "              shuffle=True, verbose=0,\n",
    "              validation_data=(test_x, test_y))\n",
    "  elif is_classic: # classic methods use flattened rows, keeping every 10th value\n",
    "    train_x = train_x.reshape(train_x.shape[0], -1)[:, 0::10]\n",
    "    test_x = test_x.reshape(test_x.shape[0], -1)[:, 0::10]\n",
    "    if model_num==4: # SVM\n",
    "      if num_class==1:\n",
    "        model = SVR(kernel='rbf', gamma='auto', max_iter=250)\n",
    "      else:\n",
    "        model = SVC(kernel='rbf', gamma='auto', decision_function_shape='ovr',\n",
    "                    max_iter=250, random_state=21)\n",
    "    elif model_num==5: # KNN\n",
    "      if num_class==1:\n",
    "        model = KNeighborsRegressor(n_neighbors=5, n_jobs=-1)\n",
    "      else:\n",
    "        model = KNeighborsClassifier(n_neighbors=5, n_jobs=-1)\n",
    "    else: # Random Forest (seeded like the SVC so folds are reproducible)\n",
    "      if num_class==1:\n",
    "        model = RandomForestRegressor(max_depth=6, n_jobs=-1, random_state=21)\n",
    "      else:\n",
    "        model = RandomForestClassifier(max_depth=6, n_jobs=-1, random_state=21)\n",
    "    model.fit(train_x, train_y.ravel()) # scikit-learn expects a 1-D target\n",
    "  else:\n",
    "    # fail fast instead of silently evaluating a `model` left over from an earlier run\n",
    "    raise ValueError('model_num must be in 1..6, got %r' % (model_num,))\n",
    "\n",
    "  # Keep the raw prediction shape: classification networks return (rows, num_class)\n",
    "  # probabilities, which must not be flattened before the argmax below.\n",
    "  train_pred = model.predict(train_x)\n",
    "  test_pred = model.predict(test_x)\n",
    "\n",
    "  if is_regression:\n",
    "    train_pred, test_pred = train_pred.reshape(-1, 1), test_pred.reshape(-1, 1)\n",
    "    train_result.append(mean_squared_error(train_y, train_pred))\n",
    "    test_result.append(mean_squared_error(test_y, test_pred))\n",
    "    reg_pred[test_index] = test_pred\n",
    "    print('train mse:', train_result[-1], '-- test mse:', test_result[-1])\n",
    "  else:\n",
    "    # reduce targets and predictions to 1-D integer class labels\n",
    "    if is_classic:\n",
    "      train_true, test_true = train_y.ravel(), test_y.ravel()\n",
    "      train_pred, test_pred = train_pred.ravel(), test_pred.ravel()\n",
    "    else:\n",
    "      train_true, test_true = np.argmax(train_y, axis=1), np.argmax(test_y, axis=1)\n",
    "      train_pred, test_pred = np.argmax(train_pred, axis=1), np.argmax(test_pred, axis=1)\n",
    "    train_result.append(np.mean(train_pred==train_true))\n",
    "    test_result.append(np.mean(test_pred==test_true))\n",
    "    print('train acc:', train_result[-1], '-- test acc:', test_result[-1])\n",
    "    # if the problem was 2class, calculate sensitivity and specificity (positive class = 1)\n",
    "    if len(mode)==2:\n",
    "      temp1, temp2, _ = sensitivity_specificity_support(train_true, train_pred, average='binary')\n",
    "      train_sensitivity.append(temp1)\n",
    "      train_specificity.append(temp2)\n",
    "      temp1, temp2, _ = sensitivity_specificity_support(test_true, test_pred, average='binary')\n",
    "      test_sensitivity.append(temp1)\n",
    "      test_specificity.append(temp2)\n",
    "      # labels=[0, 1] keeps each matrix 2x2 even when a fold contains a single class\n",
    "      train_conf_mat += confusion_matrix(train_true, train_pred, labels=[0, 1])\n",
    "      test_conf_mat += confusion_matrix(test_true, test_pred, labels=[0, 1])\n",
    "  \n",
    "\n",
    "# Collapse the per-fold score lists into arrays for the summary and the plots below.\n",
    "train_result = np.array(train_result)\n",
    "test_result = np.array(test_result)\n",
    "print('proposed model train--> mean:', train_result.mean(), 'std:', train_result.std())\n",
    "print('proposed model test--> mean:', test_result.mean(), 'std:', test_result.std())\n",
    "\n",
    "#-------------------------------------------------------------#\n",
    "# Output folder for the selected model. Fail fast on an unknown model_num instead of\n",
    "# silently reusing a save_dir left over from an earlier run of this cell, and create\n",
    "# the folder if needed so the savefig / to_csv calls below do not hit a missing directory.\n",
    "from os import makedirs\n",
    "result_dirs = {1: 'DCNN', 2: 'DCNN2', 3: 'CNN_LSTM', 4: 'SVM', 5: 'KNN', 6: 'RF'}\n",
    "if model_num not in result_dirs:\n",
    "  raise ValueError('model_num must be one of %s, got %r' % (sorted(result_dirs), model_num))\n",
    "save_dir = './related work results/' + result_dirs[model_num] + '/'\n",
    "makedirs(save_dir, exist_ok=True)\n",
    "plt.style.use('seaborn-whitegrid')\n",
    "# One per-fold figure for the training scores and one for the test scores; the\n",
    "# y-axis label and file name depend on whether this is regression, 2-class or 4-class.\n",
    "for split_name, scores in (('train', train_result), ('test', test_result)):\n",
    "  folds = range(1, scores.shape[0]+1)\n",
    "  plt.figure(figsize=(10, 8))\n",
    "  plt.plot(folds, scores, 'b.-', linewidth=2)\n",
    "  plt.xticks(folds)\n",
    "  plt.xlabel('Fold number')\n",
    "  if len(mode)==1 and mode[0]==0:\n",
    "    plt.ylabel('MSE')\n",
    "    plt.savefig(save_dir+split_name+'_MSE.jpg')\n",
    "  elif len(mode)==2:\n",
    "    plt.ylabel('Accuracy')\n",
    "    plt.savefig(save_dir+split_name+'_2class('+str(mode[0])+','+str(mode[1])+')_accuracy.jpg')\n",
    "  else:\n",
    "    plt.ylabel('Accuracy')\n",
    "    plt.savefig(save_dir+split_name+'_4class.jpg')\n",
    "\n",
    "\n",
    "if len(mode)==2:\n",
    "  # Average the confusion matrices summed in the training loop over the number of folds\n",
    "  # actually run, rather than a hard-coded 10 that is wrong as soon as n_splits changes.\n",
    "  n_folds = len(test_result)\n",
    "  train_conf_mat /= n_folds\n",
    "  test_conf_mat /= n_folds\n",
    "  print('train confusion matrix:\\n', train_conf_mat)\n",
    "  print('test confusion matrix:\\n', test_conf_mat)\n",
    "  train_sensitivity = np.array(train_sensitivity)\n",
    "  train_specificity = np.array(train_specificity)\n",
    "  test_sensitivity = np.array(test_sensitivity)\n",
    "  test_specificity = np.array(test_specificity)\n",
    "  tag = '_2class('+str(mode[0])+','+str(mode[1])+')'\n",
    "  # For train and test: one figure with accuracy + sensitivity + specificity per fold,\n",
    "  # and one with sensitivity + specificity only.\n",
    "  for split_name, acc, sen, spe in (('train', train_result, train_sensitivity, train_specificity),\n",
    "                                    ('test', test_result, test_sensitivity, test_specificity)):\n",
    "    plt.figure(figsize=(10, 8))\n",
    "    plt.plot(range(1, len(acc)+1), acc, 'r.-', label='Accuracy', linewidth=2)\n",
    "    plt.plot(range(1, len(sen)+1), sen, 'g.-', label='sensitivity', linewidth=2)\n",
    "    plt.plot(range(1, len(spe)+1), spe, 'b.-', label='specificity', linewidth=2)\n",
    "    plt.xticks(range(1, len(acc)+1))\n",
    "    plt.xlabel('Fold number')\n",
    "    plt.ylabel('Performance')\n",
    "    plt.legend()\n",
    "    plt.savefig(save_dir+split_name+tag+'_acc_sen_spe.jpg')\n",
    "\n",
    "    plt.figure(figsize=(10, 8))\n",
    "    plt.plot(range(1, len(sen)+1), sen, 'r.-', label='sensitivity', linewidth=2)\n",
    "    plt.plot(range(1, len(spe)+1), spe, 'b.-', label='specificity', linewidth=2)\n",
    "    plt.xticks(range(1, len(sen)+1))\n",
    "    plt.xlabel('Fold number')\n",
    "    plt.ylabel('Performance')\n",
    "    plt.legend()\n",
    "    plt.savefig(save_dir+split_name+tag+'_sen_spe.jpg')\n",
    "\n",
    "if len(mode)==1 and mode[0]==0:\n",
    "  # Regression only: save the out-of-fold predictions next to the true labels.\n",
    "  true_vs_pred = np.hstack([labels.reshape(-1, 1), reg_pred])\n",
    "  df = pd.DataFrame(true_vs_pred, columns=['True value', 'Predict value'])\n",
    "  df.to_csv(save_dir+'true vs pred.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_jLpH5sOQsm9"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Related works.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
